{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Understanding {class}`~tvm.relay.dataflow_pattern.DFPatternCallback`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```c++\n",
    "class DFPatternCallback;\n",
    "/*!\n",
    " * \\brief Base type of all dataflow pattern callbacks.\n",
    " * \\sa DFPatternCallback\n",
    " */\n",
    "class DFPatternCallbackNode : public Object {\n",
    " public:\n",
    "  /*! \\brief Pattern this callback matches */\n",
    "  DFPattern pattern;\n",
    "  /*! \\brief Function to call when finding a matched expression */\n",
    "  PackedFunc function;\n",
    "  /*! \\brief Require InferType to be run before the callback */\n",
    "  bool require_type;\n",
    "  /*! \\brief Run the callback only once */\n",
    "  bool rewrite_once;\n",
    "\n",
    "  void VisitAttrs(tvm::AttrVisitor* v) {\n",
    "    v->Visit(\"pattern\", &pattern);\n",
    "    v->Visit(\"require_type\", &require_type);\n",
    "    v->Visit(\"rewrite_once\", &rewrite_once);\n",
    "  }\n",
    "\n",
    "  static constexpr const char* _type_key = \"DFPatternCallbackNode\";\n",
    "  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);\n",
    "};\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
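    "`DFPatternCallbackNode` is the base type of all dataflow pattern callbacks. It inherits from `Object` and has the following member variables:\n",
    "\n",
    "1. `DFPattern pattern`: the pattern this callback matches.\n",
    "2. `PackedFunc function`: the function to call when a matching expression is found.\n",
    "3. `bool require_type`: whether `InferType` must be run before the callback.\n",
    "4. `bool rewrite_once`: whether to run the callback only once.\n",
    "\n",
    "The class also defines the member function `VisitAttrs`, which exposes `pattern`, `require_type`, and `rewrite_once` to an attribute visitor; `function` is not visited. Finally, it defines the static constant string `_type_key`, which identifies the type, and invokes the `TVM_DECLARE_BASE_OBJECT_INFO` macro to declare the type information of the class."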
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from testing import viz_expr  # helper for visualizing Relay expressions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": [
     "parameters",
     "hidden-output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mInit signature:\u001b[0m\n",
      "\u001b[0mtvm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelay\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataflow_pattern\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDFPatternCallback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mrequire_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mrewrite_once\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mSource:\u001b[0m        \n",
      "\u001b[0;32mclass\u001b[0m \u001b[0mDFPatternCallback\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0;34m\"\"\"A Callback for Pattern Rewriting.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m    When rewrite is called on this DFPatternCallback, the backend will find matches for the\u001b[0m\n",
      "\u001b[0;34m    pattern, call the callback function, and replace the matched expression with whatever\u001b[0m\n",
      "\u001b[0;34m    the callback returns.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m    Users are expect to inherit from this class and provide a \"self.pattern\" to match\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m    Parameters\u001b[0m\n",
      "\u001b[0;34m    ----------\u001b[0m\n",
      "\u001b[0;34m    require_type: bool\u001b[0m\n",
      "\u001b[0;34m        Whether InferType is required to be run before the callback.\u001b[0m\n",
      "\u001b[0;34m    rewrite_once: bool\u001b[0m\n",
      "\u001b[0;34m        If True, run the callback only once.\u001b[0m\n",
      "\u001b[0;34m    \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequire_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrewrite_once\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m        \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpattern\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m        \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequire_type\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrequire_type\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m        \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrewrite_once\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrewrite_once\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0;32mdef\u001b[0m \u001b[0mrewrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpr\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mExpr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mExpr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m        \u001b[0;34m\"\"\"\u001b[0m\n",
      "\u001b[0;34m        Rewrite expression with this callback\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m        Parameters\u001b[0m\n",
      "\u001b[0;34m        ----------\u001b[0m\n",
      "\u001b[0;34m        expr : tvm.relay.Expr\u001b[0m\n",
      "\u001b[0;34m            The expression to rewrite.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m        Returns\u001b[0m\n",
      "\u001b[0;34m        -------\u001b[0m\n",
      "\u001b[0;34m        result : tvm.relay.Expr\u001b[0m\n",
      "\u001b[0;34m            The Expression with matched subgraphs rewritten by the callbacks.\u001b[0m\n",
      "\u001b[0;34m        \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m        \u001b[0;32mreturn\u001b[0m \u001b[0mrewrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0;32mdef\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpre\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mExpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpost\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mExpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode_map\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtvm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mir\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMap\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mExpr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m        \u001b[0;34m\"\"\"\u001b[0m\n",
      "\u001b[0;34m        Callback function to use when we found a match to the pattern\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m        Parameters\u001b[0m\n",
      "\u001b[0;34m        ----------\u001b[0m\n",
      "\u001b[0;34m        pre : tvm.relay.Expr\u001b[0m\n",
      "\u001b[0;34m            The matching expression from the original graph.\u001b[0m\n",
      "\u001b[0;34m        post : tvm.relay.Expr\u001b[0m\n",
      "\u001b[0;34m            The matching expression with rewritten inputs\u001b[0m\n",
      "\u001b[0;34m        node_map : tvm.ir.container.Map[DFPattern, List[Expr]]\u001b[0m\n",
      "\u001b[0;34m            The map between patterns and matched expressions\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m        Returns\u001b[0m\n",
      "\u001b[0;34m        -------\u001b[0m\n",
      "\u001b[0;34m        result : tvm.relay.Expr\u001b[0m\n",
      "\u001b[0;34m            The Expression with matched subgraph rewritten by the callback\u001b[0m\n",
      "\u001b[0;34m        \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m        \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFile:\u001b[0m           /media/pc/data/lxw/ai/tvm/python/tvm/relay/dataflow_pattern/__init__.py\n",
      "\u001b[0;31mType:\u001b[0m           type\n",
      "\u001b[0;31mSubclasses:\u001b[0m     LayerNormRewrite, DenseReshapeBiasGeluRewrite, ResNetV1Rewrite, LegalizeQnnOpForDnnl, MulticlassNMSRewrite, PostNMSTopKRewrite, ScatterRewrite, qdistilbert_rewrite, remove_empty_pad_callback, simplify_qnn_concat_in_func, ..."
     ]
    }
   ],
   "source": [
    "tvm.relay.dataflow_pattern.DFPatternCallback??"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
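    "When `rewrite` is called on a `DFPatternCallback`, the backend finds the parts of the expression that match the pattern, calls the callback function on each match, and replaces the matched expression with whatever the callback returns.\n",
    "\n",
    "Users are expected to inherit from this class and set `self.pattern` to the pattern to match.\n",
    "\n",
    "Parameters:\n",
    "- `require_type`: `bool`, whether `InferType` must be run before the callback.\n",
    "- `rewrite_once`: `bool`, if `True`, the callback is run only once."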
   ]
  },
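  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The subclassing workflow described above can be sketched as follows. This is a minimal illustration rather than a recommended rewrite: the class name `AddToSubtract` and the variables `a` and `b` are hypothetical, and the sketch assumes a TVM build that still ships the `tvm.relay` package.\n",
    "\n",
    "```python\n",
    "import tvm\n",
    "from tvm import relay\n",
    "from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, wildcard\n",
    "\n",
    "\n",
    "class AddToSubtract(DFPatternCallback):\n",
    "    # Hypothetical callback for illustration: rewrite add(x, y) into subtract(x, y).\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()  # defaults: require_type=False, rewrite_once=False\n",
    "        self.x = wildcard()\n",
    "        self.y = wildcard()\n",
    "        # The base class initializes self.pattern to None, so it must be set here.\n",
    "        self.pattern = is_op(\"add\")(self.x, self.y)\n",
    "\n",
    "    def callback(self, pre, post, node_map):\n",
    "        # node_map maps each sub-pattern to the list of expressions it matched.\n",
    "        x = node_map[self.x][0]\n",
    "        y = node_map[self.y][0]\n",
    "        return relay.subtract(x, y)\n",
    "\n",
    "\n",
    "a = relay.var(\"a\")\n",
    "b = relay.var(\"b\")\n",
    "new_expr = AddToSubtract().rewrite(relay.add(a, b))\n",
    "print(new_expr)\n",
    "```\n",
    "\n",
    "Keeping the wildcards as attributes (`self.x`, `self.y`) is a design choice that lets `callback` look up the matched operands in `node_map` instead of relying on argument positions."
   ]
  },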
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
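    "{func}`~tvm.relay.dataflow_pattern.DFPatternCallback.callback` is the function invoked whenever a match for the pattern is found. The base class implementation raises `NotImplementedError`, so subclasses must override it.\n",
    "\n",
    "Parameters:\n",
    "- `pre`: `tvm.relay.Expr`, the matching expression from the original graph.\n",
    "- `post`: `tvm.relay.Expr`, the matching expression with its inputs already rewritten.\n",
    "- `node_map`: `tvm.ir.container.Map[DFPattern, List[Expr]]`, the map between patterns and the expressions they matched.\n",
    "\n",
    "Returns:\n",
    "- `result`: `tvm.relay.Expr`, the expression with the matched subgraph rewritten by the callback."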
   ]
  }
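  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of `post` concrete, the sketch below rewrites a nested addition. The class name `AddToSubtractPost` and the variables are hypothetical, and the example assumes a TVM build that includes `tvm.relay`. Because `post` carries the already-rewritten inputs, building the replacement from `post.args` keeps the inner rewrite when the outer match is handled.\n",
    "\n",
    "```python\n",
    "import tvm\n",
    "from tvm import relay\n",
    "from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, wildcard\n",
    "\n",
    "\n",
    "class AddToSubtractPost(DFPatternCallback):\n",
    "    # Hypothetical callback: builds the replacement from `post`.\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.pattern = is_op(\"add\")(wildcard(), wildcard())\n",
    "\n",
    "    def callback(self, pre, post, node_map):\n",
    "        # post is the matched call whose arguments have already been rewritten.\n",
    "        return relay.subtract(post.args[0], post.args[1])\n",
    "\n",
    "\n",
    "a, b, c = (relay.var(name) for name in \"abc\")\n",
    "nested = relay.add(relay.add(a, b), c)  # (a + b) + c\n",
    "result = AddToSubtractPost().rewrite(nested)\n",
    "print(result)\n",
    "```\n",
    "\n",
    "In this sketch both additions are replaced, so `result` has the structure `(a - b) - c`."
   ]
  }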
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py312x",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
